# 引入必要的库
import gradio as gr  # 用于创建交互式界面
import numpy as np  # 用于数组操作,简称为np
import torch  # 用于深度学习模型的加载和运行
from model import StyledGenerator  # 从model.py中导入StyledGenerator类
import math  # 用于数学计算

######## 1. 设置设备和加载预训练模型 ########
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设定设备,如果有CUDA支持则使用GPU,否则使用CPU
size = 1024  # 设置生成图片的大小为1024x1024
netG = StyledGenerator(512).to(device)  # 加载预训练的StyleGAN模型,并将其移动到设备上
netG.load_state_dict(torch.load('model/stylegan-1024px-new.model', map_location=device)["g_running"], strict=False)  # 加载预训练模型的权重
netG.eval()  # 将模型设置为评估模式

######## 2. 加载修改人脸属性的方向向量 ########
#加载年纪
age_direction = np.load('latent_directions/age.npy')# 加载年龄方向向量
#加载微笑程度
smile_direction = np.load('latent_directions/smile.npy')  # 加载微笑方向向量

######## 3. 定义辅助函数 ########
# 将一个表示图像的PyTorch张量的范围从任意范围转化到-1~1，然后再转化到0~1，然后再转化到0~255，然后reshape成一个图片
def make_image(tensor):
    """将张量转换为图片的函数"""
    return (
        tensor.detach()  # 将张量从计算图中分离
        .clamp_(min=-1, max=1)  # 将张量的值限制在[-1,1]范围内
        .add(1)  # 将张量的值加1,使其范围变为[0,2]
        .div_(2)  # 将张量的值除以2,使其范围变为[0,1]
        .mul(255)  # 将张量的值乘以255,使其范围变为[0,255]
        .type(torch.uint8)  # 将张量的类型转换为uint8
        .permute(0, 2, 3, 1)  # 将张量的维度顺序从[batch_size,channels,height,width]变为[batch_size,height,width,channels]
        .to("cpu")  # 将张量移动到CPU上
        .numpy()  # 将张量转换为numpy数组
    )

def transform_face(image, age: float, smile: float):
    """根据性别、年龄和微笑程度修改人脸的函数"""
    latent = torch.from_numpy(np.load("inference/yong_woman_plane_face.npy")).to(device)  # 1. 加载女孩向量
    direction = torch.from_numpy(age_direction).to(device)  # 2. 加载年龄方向向量
    #加载微笑方向向量
    direction2 = torch.from_numpy(smile_direction).to(device)
    
    latent += age * direction  # 将潜在向量沿着对应的方向移动, age是控制调节幅度的滑块的值
    latent += smile * direction2  # 将潜在向量沿着对应的方向移动, smile是控制调节幅度的滑块的值
        
    step = int(math.log(size, 2)) - 2  # 计算生成图片所需的步数
    img_gen = netG([latent], mean_style=latent, step=step, style_weight=0.3)  # 生成图片
    img = make_image(img_gen)[0]  # 将生成的张量转换为图片
    return img

######## 4. 定义Gradio接口 ########
iface = gr.Interface(
    fn=transform_face,  # 指定要调用的函数
    inputs=[
        gr.Image(),  # 输入图片,大小为1024x1024
        gr.Slider(minimum=-0.12, maximum=0.12, step=0.001,label="Age"),  # 年龄滑块,范围为[-0.12,0.12],默认值为0
        gr.Slider(minimum=-0.12, maximum=0.12, step=0.001,label="Smile")  # 微笑滑块,范围为[-0.12,0.12],默认值为0
    ], 
    outputs='image',  # 输出图片
    examples=[['inference/yong_woman_plane_face.png',0,0]]  # 示例输入
)
iface.launch()  # 启动Gradio界面
